{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {
    "collapsed": true,
    "pycharm": {
     "name": "#%% md\n"
    }
   },
   "source": [
    "# EfficientDet Training On A Custom Dataset\n",
    "\n",
    "\n",
    "\n",
    "<table align=\"left\"><td>\n",
    "  <a target=\"_blank\"  href=\"https://github.com/zylo117/Yet-Another-EfficientDet-Pytorch/blob/master/tutorial/train_shape.ipynb\">\n",
    "    <img src=\"https://www.tensorflow.org/images/GitHub-Mark-32px.png\" />View source on github\n",
    "  </a>\n",
    "</td><td>\n",
    "  <a target=\"_blank\"  href=\"https://colab.research.google.com/github/zylo117/Yet-Another-EfficientDet-Pytorch/blob/master/tutorial/train_shape.ipynb\">\n",
    "    <img width=32px src=\"https://www.tensorflow.org/images/colab_logo_32px.png\" />Run in Google Colab</a>\n",
    "</td></table>"
   ]
  },
  {
   "cell_type": "markdown",
   "source": [
    "## This tutorial shows how to train EfficientDet on a custom dataset.\n",
    "\n",
    "## For the sake of simplicity, I generated a dataset of different shapes, such as rectangles, triangles, and circles.\n",
    "\n",
    "## If you are using Colab, enable GPU support in the notebook settings to accelerate training.\n",
    "\n",
    "### 0. Install Requirements"
   ],
   "metadata": {
    "collapsed": false,
    "pycharm": {
     "name": "#%% md\n"
    }
   }
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "outputs": [],
   "source": [
    "# Use the %pip magic instead of !pip: it installs into the environment of the running\n",
    "# kernel, whereas !pip runs whichever pip happens to be first on the shell PATH.\n",
    "%pip install pycocotools numpy==1.16.0 opencv-python tqdm tensorboard tensorboardX pyyaml matplotlib\n",
    "%pip install torch==1.4.0\n",
    "%pip install torchvision==0.5.0"
   ],
   "metadata": {
    "collapsed": false,
    "pycharm": {
     "name": "#%%\n"
    }
   }
  },
  {
   "cell_type": "markdown",
   "source": [
    "### 1. Prepare Custom Dataset/Pretrained Weights (Skip this part if you already have datasets and weights of your own)"
   ],
   "metadata": {
    "collapsed": false,
    "pycharm": {
     "name": "#%% md\n"
    }
   }
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "outputs": [],
   "source": [
    "import os\n",
    "import sys\n",
    "\n",
    "repo_dir = 'Yet-Another-EfficientDet-Pytorch'\n",
    "\n",
    "# Are we already inside the repository? Keep the original path heuristic, and also accept\n",
    "# a working directory that contains the projects/ folder read at the end of this cell, so\n",
    "# that re-running the cell after the chdir below does not clone a second, nested copy.\n",
    "in_repo = \"projects\" in os.getcwd() or os.path.isdir('projects')\n",
    "if not in_repo:\n",
    "  # skip the clone when an earlier run (e.g. before a kernel restart) already fetched it\n",
    "  if not os.path.isdir(repo_dir):\n",
    "    !git clone --depth 1 https://github.com/zylo117/Yet-Another-EfficientDet-Pytorch\n",
    "  os.chdir(repo_dir)\n",
    "  sys.path.append('.')\n",
    "else:\n",
    "  !git pull\n",
    "\n",
    "# download and unzip dataset\n",
    "# (mkdir -p does not fail when the directory already exists on a re-run)\n",
    "! mkdir -p datasets\n",
    "! wget https://github.com/zylo117/Yet-Another-EfficientDet-Pytorch/releases/download/1.1/dataset_shape.tar.gz\n",
    "! tar xzf dataset_shape.tar.gz\n",
    "\n",
    "# download pretrained weights\n",
    "! mkdir -p weights\n",
    "! wget https://github.com/zylo117/Yet-Another-EfficientDet-Pytorch/releases/download/1.0/efficientdet-d0.pth -O weights/efficientdet-d0.pth\n",
    "\n",
    "# prepare project file projects/shape.yml\n",
    "# showing its contents here\n",
    "! cat projects/shape.yml"
   ],
   "metadata": {
    "collapsed": false,
    "pycharm": {
     "name": "#%%\n"
    }
   }
  },
  {
   "cell_type": "markdown",
   "source": [
    "### 2. Training"
   ],
   "metadata": {
    "collapsed": false
   }
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "outputs": [],
   "source": [
    "# This is a simple dataset, so training only the head (--head_only True) is enough.\n",
    "! python train.py -c 0 -p shape --head_only True --lr 1e-3 --batch_size 32 --load_weights weights/efficientdet-d0.pth  --num_epochs 50\n",
    "\n",
    "# The loss will be high at first.\n",
    "# Don't panic; be patient\n",
    "# and wait a little longer."
   ],
   "metadata": {
    "collapsed": false,
    "pycharm": {
     "name": "#%%\n"
    }
   }
  },
  {
   "cell_type": "markdown",
   "source": [
    "### 3. Evaluation"
   ],
   "metadata": {
    "collapsed": false
   }
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "outputs": [],
   "source": [
    "# Evaluate a trained checkpoint from logs/shape/ with coco_eval.py.\n",
    "# NOTE(review): the file name presumably encodes the epoch and step reached by training;\n",
    "# if training ran with different settings, point -w at the .pth file that actually exists.\n",
    "! python coco_eval.py -c 0 -p shape -w logs/shape/efficientdet-d0_49_1400.pth"
   ],
   "metadata": {
    "collapsed": false,
    "pycharm": {
     "name": "#%%\n"
    }
   }
  },
  {
   "cell_type": "markdown",
   "source": [
    "### 4. Visualize"
   ],
   "metadata": {
    "collapsed": false,
    "pycharm": {
     "name": "#%% md\n"
    }
   }
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "outputs": [
    {
     "data": {
      "text/plain": "<Figure size 432x288 with 1 Axes>",
      "image/png": "iVBORw0KGgoAAAANSUhEUgAAAQYAAAD8CAYAAACVSwr3AAAABHNCSVQICAgIfAhkiAAAAAlwSFlzAAALEgAACxIB0t1+/AAAADh0RVh0U29mdHdhcmUAbWF0cGxvdGxpYiB2ZXJzaW9uMy4xLjEsIGh0dHA6Ly9tYXRwbG90bGliLm9yZy8QZhcZAAAgAElEQVR4nO3de7RdVX3o8e9vrrX3PnkgCSRQTECCQAW9ijRCLARQSgX0Ch3FB/WWaNEoYC/Utha9HX1wvR311lFbexWIAsY7rIq2NrnUq+UhJvYK8VB5KYJBwyMDSUCeyck5e635u3/Mufde56xzcp775fl9xtjZa8219l7zZO/123PNNR+iqhhjTJHrdgaMMb3HAoMxpsQCgzGmxAKDMabEAoMxpsQCgzGmpC2BQUTOFpEHRWS7iFzZjmMYY9pH5rodg4gkwEPAWcDjwPeBC1X1R3N6IGNM27SjxHASsF1Vf6qqI8CXgfPacBxjTJukbXjPFcBjhfXHgZP394Jly5bpkUce2YasGGMa7rrrrqdUdflU9m1HYJgSEVkPrAc44ogjGBwc7FZWjJkXROSRqe7bjkuJncDhhfWVMW0UVd2gqqtVdfXy5VMKYsaYDmlHYPg+cIyIrBKRKvBOYHMbjmOMaZM5v5RQ1UxEPgh8C0iA61X1h3N9HGNM+7SljkFVvwF8ox3vbYxpP2v5aIwpscBgjCmxwGCMKbHAYIwpscBgjCmxwGCMKbHAYIwpscBgjCmxwGCMKbHAYIwpscBgjCmxwGCMKbHAYIwpscBgjCmxwGCMKbHAYIwpscBgjCmxwGCMKbHAYIwpscBgjCmxwGCMKbHAYIwpscBgjCmxwGCMKbHAYIwpscBgjCmxwGCMKbHAYIwpscBgjCmxwGCMKbHAYIwpscBgjCmxwGCMKbHAYIwpmTQwiMj1IrJLRO4vpB0kIjeLyE/i89KYLiLyKRHZLiL3isiJ7cy8MaY9plJi+Dxw9pi0K4FbVfUY4Na4DnAOcEx8rAeunptsGmM6adLAoKpbgF+MST4P2BiXNwLnF9K/oMEdwBIROWyuMmuM6YyZ1jEcqqpPxOWfA4fG5RXAY4X9Ho9pJSKyXkQGRWRw9+7dM8yGMaYdZl35qKoK6Axet0FVV6vq6uXLl882G8aYOTTTwPBk4xIhPu+K6TuBwwv7rYxpxpg+MtPAsBlYF5fXAZsK6RfFuxNrgOcKlxzGmD6RTraDiHwJOANYJiKPA38O/DVwo4hcDDwCvD3u/g3gXGA7sBd4TxvybIxps0kDg6peOMGmM8fZV4HLZpspY0x3WctHY0yJBQZjTIkFBmNMiQUGY0yJBQZjTIkFBmNMiQUGY0yJBQZjTIkFBmNMiQUGY0yJBQZjTIkFBmNMiQUGY0yJBQZjTIkFBmNMiQUGY0yJBQZjTIkFBmNMiQUGY0yJBQZjTIkFBmNMiQUGY0yJBQZjTIkFBmNMiQUGY0yJBQZjTIkFBmNMiQUGY0yJBQZjTIkFBmNMiQUGY0yJBQZjTIkFBmNMyaSBQUQOF5Fvi8iPROSHInJ5TD9IRG4WkZ/E56UxXUTkUyKyXUTuFZET2/1HGGPm1lRKDBnwh6p6PLAGuExEjgeuBG5V1WOAW+M6wDnAMfGxHrh6znNtjGmrSQODqj6hqv8Rl18AHgBWAOcBG+NuG4Hz4/J5wBc0uANYIiKHzXnOjTFtM606BhE5EngtcCdwqKo+ETf9HDg0Lq8AHiu87PGYZozpE1MODCKyGPgn4ApVfb64TVUV0OkcWETWi8igiAzu3r17Oi81xrTZlAKDiFQIQeGLqvrPMfnJxiVCfN4V03cChxdevjKmjaKqG1R1taquXr58+Uzzb4xpg6nclRDgOuABVf3bwqbNwLq4vA7YVEi/KN6dWAM8V7jkMMb0gXQK+5wC/C5wn4jcHdM+Cvw1cKOIXAw8Arw9bvsGcC6wHdgLvGdO
c2yMabtJA4OqfheQCTafOc7+Clw2y3wZY7rIWj4aY0osMBhjSiwwGGNKLDAYY0osMBhjSiwwGGNKLDAYY0osMBhjSiwwGGNKLDAYY0osMBhjSiwwGGNKLDAYY0osMBhjSiwwGGNKLDAYY0osMBhjSiwwGGNKLDAYY0osMBhjSiwwGGNKLDAYY0osMBhjSiwwGGNKLDAYY0osMBhjSiwwGGNKLDAYY0osMBhjSiwwGGNKLDAYY0osMPQtBQ1P/cx3OwNmXGm3M2BmbtgDCk5oBYhioJDC8nTS5+I9iukyerMWVjIgVajYN7Gn2MfRtxy1pNt5mL1KY0EVZLzIZLph0sAgIgPAFqAW9/+aqv65iKwCvgwcDNwF/K6qjohIDfgC8GvA08A7VHVHm/I/7510+rtBW1eExWvDYjF925brWXPa75XSx9t/oveYaXozTTxOwQvgE7ZtvS78DaddwLbvwOgihgWHbppKHcMw8EZVfQ1wAnC2iKwBPg58UlWPBp4BLo77Xww8E9M/Gfczc8gXTppt3/k8ojlChpChhYcUHkApfduW61Ey7ojP27ZcP+o97ojrY997ovSxxyylqUdVw3MhhGzb8rXRlyCm6yYNDBq8GFcr8aHAG4GvxfSNwPlx+by4Ttx+pojYxz5HxlbWvX7tOpwod275Ak6l+Ri7DoxaL6a9fu26cfcZ7zXF/Wf2cM33bTjtjHe0+X/NTNeU7kqISCIidwO7gJuBh4FnVTWLuzwOrIjLK4DHAOL25wiXG2Pfc72IDIrI4O7du2f3V8wjjjEfmjq8CiefdhEe13ycvPY9o9aBUevjpU223nh8b+vGcdOL28bb53tbw++FF4+XVqlny+1fGXPlYJcR3TalwKCquaqeAKwETgJeMdsDq+oGVV2tqquXL18+27ebZ1q/tr5ZXyeE8sT4j5PXrps0rbwPE77XnVtv2O9xJjyeeNAEJG/+Daec8c4+uZQIAcuTt/6HtPGPtnZRgJzmBVNcR2n9d2hv36qdVjsGVX0W+DbwemCJiDQqL1cCO+PyTuBwgLj9QEIlpGkDp4DM/S/snVs3xpN7fPvbtl/q4uVEX0SCMQR8q9TmIJ7sEgNz1twNHC6XZmWr1yR8Ti4+pLcbEU2aNxFZLiJL4vIC4CzgAUKAuCDutg7YFJc3x3Xi9ttU1cqGfWbGJ/4vM6V50jd//YUYncEjeNFm5bBPwEsG5CAaSxqxyNDjjdOm0o7hMGCjiCSEQHKjqt4kIj8CviwiHwN+AFwX978O+N8ish34BfDONuTbmM5r3EVVCcsS7rkoQiLgNBnVSCzEizScNR6QtHDrttOZn55JA4Oq3gu8dpz0nxLqG8am7wPeNie5M6ZjpnimjtktnWibjFl3o55moLPFC2v5aEzTZCdfuH7YV89wlZwqi6jzGIKSI3j2ABk5i3GMIKSkHAQ4KrwktO6ERlUFAG7SxlxSeO5ccLDAYExJ8SQsLjuGeIwFlcP5qwd+nY8eB//jkXMgGSHxipOFDGcjJBVInEdHBD+0kGq+kIF9v82lq3dSYRlOazhpNPEa29FkKh1S2s8CQ78ZcxfCS6j0QrSna7mBcG2uAi4Lty0b2nBXZVZUS3l6ke1cPXg+I4t+Bx14DX7xLwD4i5fdz1WPvQIVyH1GmgAevAo4QRYNUWeIfNFzfOLH72Dg+Zfy/pP+hsW8DIeSZ54k1XgL15MxTlsVoNNBoue/S2YyHofixIeWhT5pPYotDqeTPhfvUUwvtHp0MTi4Qv8O8b3UGyx25lJBs3AC7mEH19z3Xv549b+QL3qKP111T7w9CVc9elwr4BWDXaFlJz7hT1fdR33Bs/zhSV/lmgcugthQ3aXgfRI7kAkpMkEbBy082s9KDP1GwDfbJYLzCcTbZJnLS/tO9B5TTp/L9xCPczle/KiOXyo91NTHCzl1nPPsSx9lw+C72HvYW8iWplz1yCtBfAgG0HyejJOcjz12LLic//74cejiXfzV9tdR
ee63uPTXPsMi91JyRvBkVMhwAp5hCvdHO85KDH3IFT4273Iy5/EuB3RUI+Tir8x00ufiPYrpo37tGiWJwi+q9NgN/YSUjD38w0MX8sKye8jqyoiOhIA2A43/BUf4X8E5RgaeY++yB/n0g2/H8xwJlXDcWHpyWpu7P2gGrMTQl1on1Z1bN+5nv/6g5JPv1CniyeRF/mHb77Hn0GcRUirJCBVfCaWcGQQH7/IQzPNK8xg4hzp4YcHTfPqOy7l0zf8i8Ysb0aPrLDD0nVi89Mopb3gH6itIXkEFsl4qko9HYiUbAJ7EeYa9MJBq8xe12/bI83z2nnfz/LLtJCSAUkeoeIfKDPs3NEpHLms0ksTFTi6qGU+/dBufGLyQK1bfQE2XoeLx6uhmzYv0Qmvl1atX6+DgYLez0Rda979zhjUBZfQXqMeHdmsMVRnr2hjJoSqQJt27ng7CrcK/ueN3GPmVe0lTZVgzyKuhhaPLSXWG+RMN7aMbdUAqzToWyYW620PFH8DAo6v5g1M2kFYcXj1OEsb/UGaYDZG7VHX1VPa1EkOfaf2qJtTGtq7rQ2nzG9jJPyQ0VMpwpIDmgiQwzCMMHfojcJ7cA7RO5hkHBWjdsSi+RyzdaeJIXQVNRhg66nvUKzug/lK8LqBanexvgOb/2/6C8wyy3gulN2M6TECT+KsozfPmM1v+ApKRDmfFg6+AT/Dq+NxdH0Uqbv+D4zb6a6i06nVFQbLYnT2nuWGG8cwCg5mXwriTgkfJJXSX3nfIQ53NhGg4kVVweHwOzw08yjA7EbefS4dSn4w8LGga6nA0aQWOGbLAYOadMIBKY0yEHIlnwd4Fz3Y2IyqhtJCMgMtJq1Bf+CLXbvsQ+X7qFFSV3OeF6ockvl1GsxLHE4aHGAb2xcc0WB2DmXdG/xomrQZjkpe2tp34Zq9sXI5XYWThU4ywnQUcO+5LVDVUTCooikiYoEOylC233YvPhcSlgAOfkCQJIyPTu0SywGDmIWneNiQX8iTUN6TNInnnKkK9KM5X8Aguh0olI0scQwyzYILXOHF4r4gTJBP8MPy/7zyErztgIXhP3StprNmt5xnVyvQaTFlgMPOXZJBAxgPUgEq+kMzV23a4PzvigSnu+WPg1WPSRk/n5Rr3qNPQ3+LUc2efvyILDGZ+EoA0jLIWrx7aGRT6jQUGM78JEJtkq8br9Ta66tHjmiWHxnKjM9bY5dwPk7jGJUCrQcJ3//UhsnpoF5EkrVN47VtfztbND5eWG+vTYXclzPzUGNJdQlMnCIGh3Ron/3g9M4tB4+P3vxHvssLWYjfulDSpUqlU25ZnKzGYeSz0hWj0SnCuh34nfUJSOD23bvpZaAwlvjlEXDtLOBYYzDyWx5GTOtdfaOylxITbj3iAwlzgrD1vVXN566afsfa8VaMuFQC2bn64eckw0SXFVFlgMPOTQKv7WWdPg2JAmCg41J55Gac88UNOPTukbf3XH4fbqHnoQDHRyT5esJiJHio7GdMFCkKs4NMunw6NCX99gqsPoL6Qnzz2p3B57IA1zlBvc5h/CwxmXiqOq5DEOZdTSToaHEbNECoKeWiUlIws5PiD3wDZwJhXSGHkKyk8GklzNx6HBQYzL7lGnwKXkbIMgGEdiUPkdSoP8VkdDo/Wcuo+w9WXUNv5qx3Lx3isjsHMUxL7RqTNE1TxiHZ23CSnLs6InaAjVRKpkTz9K1RePKKj+Sjlq6tHN6bbmnNMwgD7HRmlDcd2o8aR1HQvfgiOX3oqyfBEPSU6wwKDmZ8UINyq1FiBt/DR1aSdHIlVPLgs1C+4jNQv4ICnjmbg4dPZ/0gt7WeBwcw7njhQCwrkzTGcLln7caRU4dd+zif4vAo+4TXpf6HqhJF8uOP5GJWnrh7dmC5oTgGnAqTNiv0ay6ntOoaqr5AgaBIqIp3kZJrgtXDXQsNYBx7wk9zJ8C4Pg8ECCWH6uwxHhiNJElwCiSqLd5zOwJPH4nwVke7OzmWBwcxf
pdbEwgfXfJbKU6sYqUOqA7i8iseRuizcyYDQUtLVIRnBST7pbUKX1XCx30NdE7x4quRUNTRrrmf7GNixhtfIhZAoKkp1f0O7dYAFBmMKqhzIpas/waJsKblkca4OT5Kn4Oqjb2fGOxiTnUROfLx0AZDQWMnlgCfzntreg/lPA79N7cVD0DxBcWi9u8N/W2AwZowFejR/8Ks3seD5pXjnqZOQuzyO5pyGywdfDfUC0Lq7MAEvHi+NYeihgifLa4yIY+D5Q/jj479JbfcqMs1RVdLUo5UJ364jphwYRCQRkR+IyE1xfZWI3Cki20XkKyJSjem1uL49bj+yPVk3Zu4571GUhCVcetwNLHryKJKsSpZDPRtGPOR1QZxvDTXf6PU4gUql0pyXwgN5kpFmVV7y7NH811fdgBs+kGpSwavgU08997i8u7OKTafEcDlQHJvq48AnVfVo4Bng4ph+MfBMTP9k3M+YvqAeRD2Cp6Iv54rXbeSQX7yCBdUBBgaESq1OxaWtXplTaBA1tHeY5rVEqiS6iCW/eCWXv/oGFrKKRByJpqS+QhZnLpcu99uY0tFFZCXwZuBzcV2ANwJfi7tsBM6Py+fFdeL2M6Xdw+IYM0ckdXiXoChVHLV8KZec8EUuX7GZdPup6ItLkGQvmsWZpRrTz+3nRE4HErI0BJCFPzuBPzr8Jj74uo1UdGkYIkbjxLduhAU+9IfochXDlEsMfwd8mFafj4OBZ1W1McTM48CKuLwCeAwgbn8u7j+KiKwXkUERGdy9e/cMs2/M3Au3M124a+FAkoSFrOSK0z/NFcdu4oAdJ7No6AAWiIM8xacjSJKhIuRewkkiaQgWmbB4eDEHPHwSAB9a+xlqrAAEL1kcth58Hk4ujWNDuC5PUDxp8yoReQuwS1XvEpEz5urAqroB2ABhUtu5el9j5oaMegKoMUCqjkvWXkPOEBlPc82Wq0gH4te3sheqe5Chg/G5oFrnvWs+huNAakcdDHwexwG0pvb28Rao4CQhx+Mbc1wWmkp3w1TaXZ4CvFVEzgUGgJcAfw8sEZE0lgpWAjvj/juBw4HHRSQFDgSenvOcGzPnJi+/JzJ6dvEPnzbZK74+Zt01y+nN4noVfv3NE7/DTAdbmY1JA4OqfgT4CEAsMfyRqr5LRL4KXAB8GVgHbIov2RzXvxe336adGGXTmFmZ/Cvqm/sJmkEYoLkxe5WQZRlpHLXZq8c5CZWOLieEEyHMG9cICYL3Hpc5/v2W7fis0Kqyy5cSs6n6/BPgQyKynVCHcF1Mvw44OKZ/CLhydlk0pjc4bZwwOdL8SQ0TyHrApSleGtPNOTw5YRCmpDUgC+CbJZMc58LFRK/9dk6rC5eq3g7cHpd/Cpw0zj77gLfNQd6M6S0KoT7AhZXGaEqiuEZpQNM4uGwdtBoqFltTQoQRqeP7hJ4TYaq5UQPSNodu696tCRuoxZipiiOp+eI6ECoREyAtnM+NmgjFyaidwyti705RqNdBNQwa0yt6JyfG9Lrmr/7YX/JkzD5SSBNK9RdSuIYXIW2MDyMefBpmwFYp9K/oPOsrYcxckvGWJzrDQ7o4WHtmHEY+3qrsZlAACwzGdFejMJEQ+1z0RiWkBQZjuqh5MyIFXDZph6xOscBgTBdJaAIBoqz9jeNRqZMkSdtn3Z6MBQZjeoJABZKKouotMBgzn/lGJYOCSkbuM+r1etcbPFlg6JbCtIMebX5BitOWTec9ptKk1/Qep9Js+yCknHbWq9DKSOiUAYCS1X3o2t3B+gcLDF2SSY4XxWv4cjiVGCCAGCgyQsv6ECxi8FCaffgR8JKThTmUuva3mFloTj8ZorwKvOGsV4HLEIE8z6nVaqFi0qfh0QEWGLrE4WJACBOaeslwKqQaGtCER3H/kOZFSZV4W0vJcaTxfUw/anxwAnhyMqjAqWceTZIkJFKjng3T6gbemRKDtXxsu/F/yRsDdIxuBVfeZ2zkdmP2q0x8iMgiRm+TQmxISFNt
Jo+wB5ccQOKrqDYuIzrTh8JKDL/0ilOlj30eu5/pOB27HJtQi7L2zOPxMkSmPk7A27mOVVZi6IhicTFWMubKi074/c99lZ35Um75wFkAnHzNVvb5Ybz31Go1sizj7kvO5LVX3w7ikaSGjOQMa8ZRvMDmD57HPp8x4JJR7986XvH445Uexu5jOqp4njcqkyV8jq4Ga3/zWL777ftheEHYuUNjNVhg6Ihw/Qi+WXn4FMLvf/pf2J6+lB984FROvPpW8sQhPkNclQTP4PtOH/UeXhMWZDlDicfh2MGBADxMwiHA8v0ev/TTNGZ5nM4+pkMalw+Fy4r4ebgqrD3zVXz3mw+17kx0IDhYYOgYT6ahP/6ODN73uZt5NlkKcTxdJQkD/QhxDHNCKaFJcUAdxWmYijWJt7QuvvY2lrt9/J/3hz3reU4lcaNeOz4LBr1Bxl0MPJImrD37WL5zy33oyACpJGSZkqTSagylCY4wuc2c5KjbDSkgDAY7ODjY7Wy0ieDxuFwYTuBRhbd95hbEzXaqofANciguAXLPYj/M1z5wNksbwwo2Ws+pWhVCT5hJIA77q5cwUn0G37vlp6Au3s4sliCKpb/R77H2rUcjwl2qunoqR7XKxzZqxG6HJ3NhjP311/4bqQww0zO10R3XKZB7PIK4ChkJz7OID3x2E3vHvrUFhf7UrIwUcq2TkyMVOOVNR1FPn8G7HJ/JmLsVc1OxbIGhjUQbrRkT9gKXX/stnskHyLXOXBThnQsf33A9Ixclr8AjHMgjwNCoSq36rI9luqDwGaZJJXSucoqmntPfdAKnnPVyfHUPKgI+RaRxaRHGkBQJg83OhAWGDhAPF13zTX42lCJpBm7m/+3NhkyieBnVJprcj5CgvPuz/86TozJgH3Nf08IDQdQhCJLAGW9+Jae9ZRVrz3kZVOuIA685xT5YOoNRX6zysY3Eh0kIhgQedgup1sBl9Thy8MyEIb90VH8KFyskRSqMkDCU7ePyDTexaX3YnpHYB93PineUFbzXMDQ9NEsFruY45axjkQzI4fbb7wV15OpJ0+nXZ9n3pQN+LJD4jAVZjaFa7Bfh8xm/n1Mhi0WH1pDmkGceJ0JShR060Np/Npk3vaMRIJJQiek13LZ2zuE1D+1cqg6vcNrZr8RJGJF6JvcXLDC0k4bf9T+55pvU3CKG0ix8gLPoCOOlFRCaaYSTP02UDMVlFZKkFXgsMPSS2dcENz5PV3iriZYRmMnQDhYY2mjYKTVgDwn7CHMPOG1P24EwkSq4XMApvvBtaE6VaLqs200Dph4h7PvSRpX4OezF4fN6OHnb1GrNC2HqdOepKmje+hJYz0szXRYY2shJnFvAJ1TTVuGsHY1ZnYKoCw2vVdARu0VpZs4CQxsNx+eKCPVcC0ODz/1/uwOqGkoOmVNcpXAM7f6ow6a/WGBoo8bpOJwNU1GHxkuIVOf+P74xJJzD4xSKNzR9j8xVYPqHBYY2alzbp2lK3dVJfBLbIfi2XE548aAJPslIkgWtDZJM/CJjxmGBoY1q8dkTSvPOC67Nv94+zsAs2czbSRhjgaEDqkkVcmIzZt++/3R1IfCoIHkrMPRAB1rTZywwtFE91ikkXllYTaknOcNCOIHbcDwvoT7B+YSs0roLklhgMNNkgaGNGuM8JzpMlmUkCok6MqStw707IHeFSwnrdm2myQJDG6V5+KleVXkR0irOJ6Rt/fX2pD50sqr4QjsGCwxmmqYUGERkh4jcJyJ3i8hgTDtIRG4WkZ/E56UxXUTkUyKyXUTuFZET2/kH9LYQBf7yveeT1+vURag7z4C2pzWig+ZdjwO6PPeh6W/TKTG8QVVPKAwNdSVwq6oeA9wa1wHOAY6Jj/XA1XOV2b4Tx104DFgqdbxmqBeGJYsVkbHI7/I4Vl+jL8XEGl3rU8I8A7nzOAmPTIWRxJOrcmx9qPWaNv155pfXbC4lzgM2xuWNwPmF9C9ocAewREQOm8Vx+lZGuM6vAZ+7
7E0cJDm1TKkmVTJ1LMgctTxMO5b6BC9KPdn/bcY0nuaZpnhNEJ/gNcGrQ0ZykhHlKF7kY5f85+ZrrK+Ema6pBgYF/k1E7hKROPwHh6rqE3H558ChcXkF8FjhtY/HtFFEZL2IDIrI4O7du2eQ9d7XOIkdyvIcjuRZlIysnuNcxlAaGjo5JY6v4EnQ/Xay8upIgVy02QzaSeiLn9YqLK4Oce36t3JwsU2TFRnMNE01MJyqqicSLhMuE5HTihs1DDU9rd8lVd2gqqtVdfXy5RPPiNDP8kKLwwWJ8plLLuAo2cOw1MPMQq4eBvR0ORrbNzifhB6YE/AIHqiRg+SMNO4+iKJ+hOvffw5LkjEfrFU3mGmaUmBQ1Z3xeRfwdeAk4MnGJUJ83hV33wkcXnj5ypg27ySN/14VKngWq+dLl53PcbxAmkOaVck1RUmp+RAUMnXs70xOFTLxDDtAKkie4LTKUq/ceMlvcLRCtT76g53NUHJmfpr0KyMii0TkgMYy8JvA/cBmYF3cbR2wKS5vBi6KdyfWAM8VLjnmmVYhypOQiZCiXP/+3+J49vAStxfv64iHDBd6X8r+C1+ZU5w6Kt4xkCtp6lmZvMAN7zuTVXEUcamA7/qgIKafTWUEp0OBr0u4/ZUC/6iq3xSR7wM3isjFwCPA2+P+3wDOBbYDe4H3zHmu+4VKHOpfWjNXC7wkgY2XtvfQbtRy5yZDNb8cemImKhF5AXiw2/mYomXAU93OxBT0Sz6hf/LaL/mE8fP6MlWdUoVer4z5+OBUp87qNhEZ7Ie89ks+oX/y2i/5hNnn1aqljDElFhiMMSW9Ehg2dDsD09Avee2XfEL/5LVf8gmzzGtPVD4aY3pLr5QYjDE9pOuBQUTOFpEHYzftKyd/RVvzcr2I7BKR+wtpPdm9XEQOF5Fvi8iPROSHInJ5L+ZXRAZEZJuI3BPz+ZcxfZWI3Bnz8xURqcb0WlzfHrcf2Yl8FvKbiMgPROSmHs9ne4dCUNWuPYAEeBg4CqgC9wDHdzE/pwEnAvcX0v4ncGVcvhL4eFw+F/i/hJZDa4A7O5zXw4AT4/IBwEPA8b2W33i8xXG5AtwZj38j8M6Yfg1wSVy+FLgmLr8T+EqH/18/BPwjcFNc79V87qgn+38AAAIxSURBVACWjUmbs8++Y3/IBH/c64FvFdY/Anyky3k6ckxgeBA4LC4fRmhzAXAtcOF4+3Up35uAs3o5v8BC4D+AkwmNb9Kx3wPgW8Dr43Ia95MO5W8lYWyRNwI3xROp5/IZjzleYJizz77blxJT6qLdZbPqXt4JsRj7WsKvcc/lNxbP7yZ0tLuZUEp8VlWzcfLSzGfc/hxwcCfyCfwd8GFaHdUP7tF8QhuGQijqlZaPfUFVVaS3pnUSkcXAPwFXqOrzUhjSrVfyq6o5cIKILCH0zn1Fl7NUIiJvAXap6l0icka38zMFp6rqThE5BLhZRH5c3Djbz77bJYZ+6KLds93LRaRCCApfVNV/jsk9m19VfRb4NqFIvkREGj9Mxbw08xm3Hwg83YHsnQK8VUR2AF8mXE78fQ/mE2j/UAjdDgzfB46JNb9VQiXO5i7naaye7F4uoWhwHfCAqv5tr+ZXRJbHkgIisoBQD/IAIUBcMEE+G/m/ALhN44VxO6nqR1R1paoeSfge3qaq7+q1fEKHhkLoVGXJfipRziXUqD8M/Lcu5+VLwBNAnXAddjHhuvFW4CfALcBBcV8BPh3zfR+wusN5PZVwnXkvcHd8nNtr+QVeDfwg5vN+4M9i+lHANkL3/K8CtZg+ENe3x+1HdeF7cAatuxI9l8+Yp3vi44eN82YuP3tr+WiMKen2pYQxpgdZYDDGlFhgMMaUWGAwxpRYYDDGlFhgMMaUWGAwxpRYYDDGlPx/iALj2+UtxycAAAAASUVORK5CYII=\n"
     },
     "metadata": {
      "needs_background": "light"
     },
     "output_type": "display_data"
    }
   ],
   "source": [
    "import torch\n",
    "from torch.backends import cudnn\n",
    "\n",
    "from backbone import EfficientDetBackbone\n",
    "import cv2\n",
    "import matplotlib.pyplot as plt\n",
    "import numpy as np\n",
    "\n",
    "from efficientdet.utils import BBoxTransform, ClipBoxes\n",
    "from utils.utils import preprocess, invert_affine, postprocess\n",
    "\n",
    "# ---- settings a reader may want to tune ----\n",
    "compound_coef = 0\n",
    "force_input_size = None  # set None to use default size\n",
    "img_path = 'datasets/shape/val/999.jpg'\n",
    "# same checkpoint file as the one passed to coco_eval.py in the evaluation step\n",
    "weights_path = 'logs/shape/efficientdet-d0_49_1400.pth'\n",
    "\n",
    "threshold = 0.2\n",
    "iou_threshold = 0.2\n",
    "\n",
    "# use the GPU when there is one; fall back to the CPU instead of crashing on .cuda()\n",
    "use_cuda = torch.cuda.is_available()\n",
    "use_float16 = False\n",
    "cudnn.fastest = True\n",
    "cudnn.benchmark = True\n",
    "\n",
    "obj_list = ['rectangle', 'circle']\n",
    "\n",
    "# tf bilinear interpolation is different from any other's, just make do\n",
    "input_sizes = [512, 640, 768, 896, 1024, 1280, 1280, 1536]\n",
    "input_size = input_sizes[compound_coef] if force_input_size is None else force_input_size\n",
    "ori_imgs, framed_imgs, framed_metas = preprocess(img_path, max_size=input_size)\n",
    "\n",
    "device = torch.device('cuda' if use_cuda else 'cpu')\n",
    "dtype = torch.float16 if use_float16 else torch.float32\n",
    "\n",
    "# stack the preprocessed images into one batch, move it to the target device and dtype,\n",
    "# then reorder the axes with permute(0, 3, 1, 2) before feeding the model\n",
    "x = torch.stack([torch.from_numpy(fi) for fi in framed_imgs], 0)\n",
    "x = x.to(device).to(dtype).permute(0, 3, 1, 2)\n",
    "\n",
    "model = EfficientDetBackbone(compound_coef=compound_coef, num_classes=len(obj_list),\n",
    "\n",
    "                             # replace this part with your project's anchor config\n",
    "                             ratios=[(1.0, 1.0), (1.4, 0.7), (0.7, 1.4)],\n",
    "                             scales=[2 ** 0, 2 ** (1.0 / 3.0), 2 ** (2.0 / 3.0)])\n",
    "\n",
    "# map_location='cpu' lets a checkpoint that was saved on a GPU load on any machine;\n",
    "# the model is moved to the target device right below\n",
    "model.load_state_dict(torch.load(weights_path, map_location='cpu'))\n",
    "model.requires_grad_(False)\n",
    "model.eval()\n",
    "\n",
    "model = model.to(device)\n",
    "if use_float16:\n",
    "    model = model.half()\n",
    "\n",
    "with torch.no_grad():\n",
    "    features, regression, classification, anchors = model(x)\n",
    "\n",
    "    regressBoxes = BBoxTransform()\n",
    "    clipBoxes = ClipBoxes()\n",
    "\n",
    "    out = postprocess(x,\n",
    "                      anchors, regression, classification,\n",
    "                      regressBoxes, clipBoxes,\n",
    "                      threshold, iou_threshold)\n",
    "\n",
    "out = invert_affine(framed_metas, out)\n",
    "\n",
    "for i in range(len(ori_imgs)):\n",
    "    # draw every detection of this image first (the loop is a no-op without detections) ...\n",
    "    for j in range(len(out[i]['rois'])):\n",
    "        # the builtin int replaces the np.int alias, which newer NumPy releases removed\n",
    "        (x1, y1, x2, y2) = out[i]['rois'][j].astype(int)\n",
    "        cv2.rectangle(ori_imgs[i], (x1, y1), (x2, y2), (255, 255, 0), 2)\n",
    "        obj = obj_list[out[i]['class_ids'][j]]\n",
    "        score = float(out[i]['scores'][j])\n",
    "\n",
    "        cv2.putText(ori_imgs[i], '{}, {:.3f}'.format(obj, score),\n",
    "                    (x1, y1 + 10), cv2.FONT_HERSHEY_SIMPLEX, 0.5,\n",
    "                    (255, 255, 0), 1)\n",
    "\n",
    "    # ... then render the image once, in its own figure, so that several images do not\n",
    "    # overwrite each other and an image without detections is still shown.\n",
    "    # NOTE(review): if preprocess() reads images with cv2.imread they are BGR and matplotlib\n",
    "    # will display swapped colours - confirm, and convert with cv2.cvtColor if so.\n",
    "    plt.figure()\n",
    "    plt.imshow(ori_imgs[i])"
   ],
   "metadata": {
    "collapsed": false,
    "pycharm": {
     "name": "#%%\n"
    }
   }
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 2
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython2",
   "version": "2.7.6"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 0
}